"""Shared entities."""
import os
from enum import Enum
from pathlib import Path
from typing import Type

from bigym.bigym_env import BiGymEnv
from bigym.envs.cupboards import (
    DrawerTopOpen,
    DrawerTopClose,
    DrawersAllOpen,
    DrawersAllClose,
    CupboardsCloseAll,
    CupboardsOpenAll,
    WallCupboardOpen,
    WallCupboardClose,
)
from bigym.envs.dishwasher import (
    DishwasherOpen,
    DishwasherClose,
    DishwasherOpenTrays,
    DishwasherCloseTrays,
)
from bigym.envs.dishwasher_cups import (
    DishwasherUnloadCups,
    DishwasherLoadCups,
    DishwasherUnloadCupsLong,
)
from bigym.envs.dishwasher_cutlery import (
    DishwasherUnloadCutlery,
    DishwasherLoadCutlery,
    DishwasherUnloadCutleryLong,
)
from bigym.envs.dishwasher_plates import (
    DishwasherUnloadPlates,
    DishwasherLoadPlates,
    DishwasherUnloadPlatesLong,
)
from bigym.envs.groceries import GroceriesStoreLower, GroceriesStoreUpper
from bigym.envs.mainpulation import FlipCup, FlipCutlery, StackBlocks
from bigym.envs.move_plates import MovePlate, MoveTwoPlates
from bigym.envs.pick_and_place import (
    TakeCups,
    PutCups,
    PickBox,
    StoreBox,
    SaucepanToHob,
    StoreKitchenware,
    ToastSandwich,
    FlipSandwich,
    RemoveSandwich,
)
from bigym.envs.reach_target import ReachTarget, ReachTargetDual, ReachTargetSingle
from demonstrations.const import SAFETENSORS_SUFFIX


class ReplayMode(Enum):
    """How joint positions are interpreted while a demo is replayed.

    ``Absolute`` and ``Delta`` are the two supported modes; their integer
    values are fixed at 0 and 1 respectively.
    """

    Absolute = 0
    Delta = 1


# Lookup from a mode's name to its enum member, in definition order.
REPLAY_MODES: dict[str, ReplayMode] = {mode.name: mode for mode in ReplayMode}
# Registry mapping a human-readable task name to its BiGym environment class.
# Keys use single spaces between words throughout.
ENVIRONMENTS: dict[str, Type[BiGymEnv]] = {
    "Reach Target": ReachTarget,
    "Reach Target Single": ReachTargetSingle,
    "Reach Target Dual": ReachTargetDual,
    "Stack Blocks": StackBlocks,
    "Move Plate": MovePlate,
    "Move Two Plates": MoveTwoPlates,
    "Dishwasher Open": DishwasherOpen,
    "Dishwasher Close": DishwasherClose,
    "Dishwasher Open Trays": DishwasherOpenTrays,
    "Dishwasher Close Trays": DishwasherCloseTrays,
    "Unload Plates": DishwasherUnloadPlates,
    "Unload Plates Long": DishwasherUnloadPlatesLong,
    "Load Plates": DishwasherLoadPlates,
    "Unload Cutlery": DishwasherUnloadCutlery,
    "Unload Cutlery Long": DishwasherUnloadCutleryLong,
    "Load Cutlery": DishwasherLoadCutlery,
    "Unload Cups": DishwasherUnloadCups,
    "Unload Cups Long": DishwasherUnloadCupsLong,
    "Load Cups": DishwasherLoadCups,
    "Drawer Top Open": DrawerTopOpen,
    "Drawer Top Close": DrawerTopClose,
    "Drawers All Open": DrawersAllOpen,
    "Drawers All Close": DrawersAllClose,
    "Wall Cupboard Open": WallCupboardOpen,
    # Previously "Wall Cupboard  Close" (double space), inconsistent with
    # "Wall Cupboard Open" and every other key.
    "Wall Cupboard Close": WallCupboardClose,
    "Cupboards Open All": CupboardsOpenAll,
    "Cupboards Close All": CupboardsCloseAll,
    "Take Cups": TakeCups,
    "Put Cups": PutCups,
    "Flip Cup": FlipCup,
    "Flip Cutlery": FlipCutlery,
    "Pick Box": PickBox,
    "Store Box": StoreBox,
    "Saucepan To Hob": SaucepanToHob,
    "Store Kitchenware": StoreKitchenware,
    "Toast Sandwich": ToastSandwich,
    "Flip Sandwich": FlipSandwich,
    "Remove Sandwich": RemoveSandwich,
    "Groceries Store Lower": GroceriesStoreLower,
    "Groceries Store Upper": GroceriesStoreUpper,
}


def get_demos_in_dir(directory: Path) -> list[Path]:
    """Get all demonstration files in a directory.

    Only direct children are inspected (no recursion). A child is kept when
    it is a regular file whose suffix equals ``SAFETENSORS_SUFFIX``.

    Args:
        directory: Directory to scan. Any path-like value (e.g. ``str``) is
            accepted; it is coerced to ``Path``.

    Returns:
        Matching file paths (``directory / name``), sorted.

    Raises:
        FileNotFoundError / NotADirectoryError: If ``directory`` cannot be listed.
    """
    # Coerce so callers may pass a plain string; `str / str` would fail.
    root = Path(directory)
    return sorted(
        child
        for child in root.iterdir()
        if child.is_file() and child.suffix == SAFETENSORS_SUFFIX
    )
